Phase Mask Design
In this notebook we illustrate the inverse design of a phase mask, taking the example from Wong et al. (2021): designing a diffractive pupil phase mask for the Toliman telescope.
To obtain high-precision centroids, we need to maximize the gradient energy of the point spread function (PSF) produced by the pupil. To satisfy fabrication constraints, we need a binary mask whose phases take only the values {0, π}.
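A minimal sketch of this kind of objective, under stated assumptions: the point spread function (PSF) is computed from the pupil field with a zero-padded FFT (a simple Fraunhofer model), and gradient energy is taken to be the sum of squared finite-difference gradients of the normalized PSF. The helpers psf_from_phase and gradient_energy, and the twelve-sector test pattern, are hypothetical and for illustration only. The exact figure of merit used by Wong et al. (2021) may differ.

```python
import jax
import jax.numpy as jnp

def psf_from_phase(support, phase, pad=2):
    """Hypothetical helper: far-field PSF via a zero-padded FFT (simple Fraunhofer model)."""
    n = support.shape[0]
    field = support * jnp.exp(1j * phase)          # complex pupil field
    padded = jnp.pad(field, (0, (pad - 1) * n))     # zero-pad each axis to pad * n
    far = jnp.fft.fftshift(jnp.fft.fft2(padded))    # far-field amplitude, centred
    psf = far.real ** 2 + far.imag ** 2             # intensity (smoothly differentiable)
    return psf / psf.sum()                          # normalize to unit total flux

def gradient_energy(psf):
    """Hypothetical helper: sum of squared finite-difference PSF gradients (assumed definition)."""
    gy, gx = jnp.gradient(psf)
    return jnp.sum(gx ** 2 + gy ** 2)

# Toy annular pupil on a small grid
n_toy = 64
x = jnp.linspace(-1, 1, n_toy)
X, Y = jnp.meshgrid(x, x)
R = jnp.hypot(X, Y)
support = ((R <= 1) & (R >= 0.175)).astype(jnp.float32)

flat_phase = jnp.zeros((n_toy, n_toy))
# Hypothetical binary pattern: twelve azimuthal sectors alternating between 0 and pi
binary_phase = jnp.pi * (jnp.sin(6 * jnp.arctan2(Y, X)) > 0)

ge_flat = gradient_energy(psf_from_phase(support, flat_phase))
ge_binary = gradient_energy(psf_from_phase(support, binary_phase))
print(ge_flat, ge_binary)

# The figure of merit is differentiable with respect to every phase pixel
grad_fn = jax.grad(lambda ph: gradient_energy(psf_from_phase(support, ph)))
g = grad_fn(binary_phase)
```

Because the objective is differentiable, jax.grad returns a gradient for every phase pixel, which is what allows gradient-based design. The {0, π} constraint is not enforced by this sketch and has to be imposed separately.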
import jax
import jax.numpy as np  # note: jax.numpy is imported as np throughout this notebook
from jax import vmap
import matplotlib.pyplot as plt
from tqdm import tqdm
%matplotlib inline
plt.rcParams['image.cmap'] = 'hot'
plt.rcParams["text.usetex"] = False
plt.rcParams['figure.dpi'] = 120
We will first generate an orthonormal basis for the pupil phase, and then threshold the resulting phase pattern to {0, 1} (i.e. phases of 0 or π) while preserving soft edges, using the Continuous Latent Image Mask Binarization (CLIMB) algorithm from Wong et al. (2021).
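The general idea of binarizing a continuous latent image at higher resolution and then reducing it to soft edges can be illustrated with a much simpler, hypothetical helper, threshold_and_downsample. It hard-thresholds an oversampled latent image and block-averages it down, so pixels that straddle the 0/1 boundary take fractional values. This is not the CLIMB algorithm from Wong et al. (2021). In particular, the hard threshold here passes no useful gradient, so on its own it is not suitable for gradient-based optimization.

```python
import jax.numpy as jnp

def threshold_and_downsample(latent, oversample):
    """Hypothetical helper: binarize an oversampled latent image, then block-average.

    latent: square array of shape (n * oversample, n * oversample)
    returns: array of shape (n, n) with values in [0, 1]
    """
    binary = (latent > 0).astype(jnp.float32)        # hard threshold: 1 where latent > 0
    n = latent.shape[0] // oversample
    blocks = binary.reshape(n, oversample, n, oversample)
    return blocks.mean(axis=(1, 3))                  # fraction of sub-pixels equal to 1

# Toy latent image: a vertical edge that falls inside one block of sub-pixels
x = jnp.linspace(-1, 1, 12)
X, Y = jnp.meshgrid(x, x)
soft = threshold_and_downsample(X - 0.1, oversample=3)
print(soft[0])   # fractions 0, 0, 2/3, 1 across the row
```

With this notebook's settings (wf_npix = 256, oversample = 3), a 768 × 768 latent image would be reduced to a 256 × 256 soft-edged mask.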
from sklearn.decomposition import PCA
Generate the support of the pupil, an annulus with outer radius 1 and inner radius 0.175 in normalized coordinates:
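wf_npix = 256                    # wavefront size in pixels
oversample = 3                   # oversampling factor of the design grid
nslice = 3
npix = wf_npix * oversample      # 768 pixels across the oversampled grid
c = (npix - 1) / 2.
xs = (np.arange(npix) - c) / c   # pixel coordinates normalized to [-1, 1]
XX, YY = np.meshgrid(xs, xs)
RR = np.sqrt(XX ** 2 + YY ** 2)  # radius
PHI = np.arctan2(YY, XX)         # azimuthal angle
mask = np.logical_and(RR <= 1, RR >= 0.175).astype(float)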
plt.imshow(mask)
plt.colorbar()
plt.show()
Generate basis vectors however you like. Here we use logarithmic radial harmonic functions (LRHFs) of the form cos(A ln r + Bθ + C), together with sines and cosines in r, but any basis will do. This code is not important: just generate your favourite not-necessarily-orthonormal basis, and we will orthonormalize it with PCA later on.
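For reference, with r = RR and θ = PHI, the code that follows generates these functions (it adds 1e-12 to r inside the logarithm to avoid log 0):

```latex
\begin{aligned}
\mathrm{LRHF}_{A,B,C}(r,\theta) &= \cos\left(A \ln r + B\theta + C\right),
  && A \in \{-10,\dots,10\},\ B \in \{0, 3, \dots, 24\},\ C \in \{-\tfrac{\pi}{2}, \tfrac{\pi}{2}\}, \\
s_i(r) &= \sin(i\pi r),\quad c_i(r) = \cos(i\pi r),
  && i \in \{-10,\dots,10\}.
\end{aligned}
```

This gives 21 × 9 × 2 = 378 LRHFs plus 21 sines and 21 cosines, 420 vectors in all. Because B is a multiple of 3, every LRHF is unchanged by a rotation of 2π/3. The set is also redundant: cos(x ± π/2) = ∓sin x, so the two values of C give the same function up to sign, and likewise s_{-i} = -s_i, c_{-i} = c_i and s_0 = 0. PCA copes with this, since it returns an orthonormal set regardless of linear dependence in its input.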
# Frequencies and phase offsets for the basis functions
a = 10
b = 8
ith = 10
As = np.arange(-a, a+1)              # log-radial frequencies: -10..10
Bs = 3 * np.arange(0, b+1)           # azimuthal orders: 0, 3, ..., 24
Cs = np.array([-np.pi/2, np.pi/2])   # phase offsets (note cos(x ± π/2) = ∓sin x)
Is = np.arange(-ith, ith+1)          # radial frequencies for sin/cos(i*pi*r)

# Logarithmic radial harmonic function (LRHF); 1e-12 avoids log(0)
LRHF_fn = lambda A, B, C, RR, PHI: np.cos(A*np.log(RR + 1e-12) + B*PHI + C)
sine_fn = lambda i, RR: np.sin(i * np.pi * RR)
cose_fn = lambda i, RR: np.cos(i * np.pi * RR)

# Nested vmaps evaluate LRHF_fn over every (C, A, B) combination,
# giving shape (len(Cs), len(As), len(Bs), npix, npix)
gen_LRHF_basis = vmap(vmap(vmap(LRHF_fn, (None, 0, None, None, None)),
                           (0, None, None, None, None)),
                      (None, None, 0, None, None))
gen_sine_basis = vmap(sine_fn, in_axes=(0, None))
gen_cose_basis = vmap(cose_fn, in_axes=(0, None))

LRHF_basis = gen_LRHF_basis(As, Bs, Cs, RR, PHI).reshape([len(As)*len(Bs)*len(Cs), npix, npix])
sine_basis = gen_sine_basis(Is, RR)
cose_basis = gen_cose_basis(Is, RR)

# Flatten every basis image to a vector and stack: shape (420, npix*npix)
LRHF_flat = LRHF_basis.reshape([len(As)*len(Bs)*len(Cs), npix*npix])
sine_flat = sine_basis.reshape([len(sine_basis), npix*npix])
cose_flat = cose_basis.reshape([len(cose_basis), npix*npix])
full_basis = np.concatenate([
    LRHF_flat,
    sine_flat,
    cose_flat
])
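The triple vmap in gen_LRHF_basis maps LRHF_fn over Bs (innermost), then As, then Cs (outermost), so its output has shape (len(Cs), len(As), len(Bs), npix, npix) before being reshaped to (378, npix, npix). The self-contained check below confirms that axis ordering using a made-up toy_fn and toy_* arrays in place of LRHF_fn and the notebook's grids.

```python
import jax.numpy as jnp
from jax import vmap

def toy_fn(A, B, C, R, P):
    """Stand-in for LRHF_fn: a 2x2 'image' whose value encodes (A, B, C)."""
    return 100 * A + 10 * B + C + 0 * R

toy_R = jnp.zeros((2, 2))
toy_P = jnp.zeros((2, 2))
toy_As = jnp.arange(3)   # 3 values of A
toy_Bs = jnp.arange(4)   # 4 values of B
toy_Cs = jnp.arange(2)   # 2 values of C

# Same nesting as gen_LRHF_basis: map over B (innermost), then A, then C (outermost)
gen = vmap(vmap(vmap(toy_fn, (None, 0, None, None, None)),
                (0, None, None, None, None)),
           (None, None, 0, None, None))
out = gen(toy_As, toy_Bs, toy_Cs, toy_R, toy_P)
print(out.shape)   # (2, 3, 4, 2, 2), i.e. (C, A, B, row, column)

# Flattening as the notebook does puts index (c, a, b) at (c * 3 + a) * 4 + b
flat = out.reshape(-1, 2, 2)
```

Since the reshape only flattens the leading axes, the order of the 378 LRHF vectors is C-major, then A, then B.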
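Orthonormalize with PCA; you could also use Gram-Schmidt if you prefer. Since PCA subtracts the mean of the input vectors before decomposing them, we keep the 99 leading principal components (orthonormal over the full square grid, not just inside the mask) and prepend a constant piston image scaled by the average of pca.mean_, giving 100 basis vectors.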
%%time
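pca = PCA().fit(full_basis)  # PCA centers the input vectors before decomposing them
components = pca.components_.reshape([len(full_basis), npix, npix])
components = np.copy(components[:99, :, :])  # keep the 99 leading components
piston = np.mean(pca.mean_) * np.ones((1, npix, npix))  # constant piston term
basis = np.concatenate([piston, components])  # 100 basis vectors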
CPU times: user 2min 36s, sys: 1.19 s, total: 2min 37s
Wall time: 29.5 s
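If you prefer Gram-Schmidt, a minimal sketch of modified Gram-Schmidt over the rows of an [n_vectors, n_pixels] array such as full_basis might look like this. The helper name gram_schmidt and the relative tolerance tol are illustrative choices. Rows whose residual falls below the tolerance are dropped, which removes zero vectors and sign-flipped duplicates. Unlike PCA, it does not subtract a mean, and the output follows the input order rather than explained variance. It is a plain Python loop, so expect it to be slower than a vectorized decomposition.

```python
import jax.numpy as jnp

def gram_schmidt(vectors, tol=1e-4):
    """Modified Gram-Schmidt on the rows of `vectors` (shape [n_vectors, n_pixels]).

    Rows whose residual norm, relative to their original norm, is at or below
    `tol` (e.g. zero vectors or sign-flipped duplicates) are dropped.
    """
    ortho = []
    for v in vectors:
        w = v
        for q in ortho:
            w = w - jnp.dot(q, w) * q      # remove the component along q
        norm = jnp.linalg.norm(w)
        if norm > tol * jnp.linalg.norm(v):
            ortho.append(w / norm)         # keep the normalized residual
    if not ortho:
        return jnp.zeros((0, vectors.shape[1]), dtype=vectors.dtype)
    return jnp.stack(ortho)

# Toy redundant "basis": sin, its negative, cos, and a zero vector
r = jnp.linspace(0.1, 1.0, 50)
toy = jnp.stack([jnp.sin(jnp.pi * r), -jnp.sin(jnp.pi * r),
                 jnp.cos(jnp.pi * r), jnp.zeros_like(r)])
Q = gram_schmidt(toy)
print(Q.shape)   # (2, 50): the negated copy and the zero vector are dropped
```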
Show the pretty basis vectors:
nfigs = 100                   # number of basis vectors to show
ncols = 10
nrows = -(-nfigs // ncols)    # ceiling division, so no empty row is added
plt.figure(figsize=(4*ncols, 4*nrows))
for i in range(nfigs):
    plt.subplot(nrows, ncols, i+1)
    plt.imshow(basis[i], cmap='seismic')
    plt.xticks([])
    plt.yticks([])
plt.tight_layout()
plt.show()